{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from path_explain.utils import set_up_environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_up_environment(visible_devices='2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 10\n",
    "test_size  = (32, 64)\n",
    "data = np.random.randn(batch_size, *test_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.models.Sequential()\n",
    "model.add(tf.keras.layers.Input(shape=test_size))\n",
    "model.add(tf.keras.layers.Conv1D(filters=16, kernel_size=3, strides=1, activation=tf.keras.activations.softplus))\n",
    "model.add(tf.keras.layers.Flatten())\n",
    "model.add(tf.keras.layers.Dense(units=32, activation=tf.keras.activations.softplus))\n",
    "model.add(tf.keras.layers.Dense(units=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_input = tf.convert_to_tensor(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.GradientTape() as second_order_tape:\n",
    "    second_order_tape.watch(batch_input)\n",
    "    with tf.GradientTape() as first_order_tape:\n",
    "        first_order_tape.watch(batch_input)\n",
    "        output = model(batch_input)\n",
    "    gradient = first_order_tape.gradient(output, batch_input)\n",
    "\n",
    "hessian = second_order_tape.batch_jacobian(gradient, batch_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_times_hessian = hessian * batch_input[:, :, :, tf.newaxis, tf.newaxis] * batch_input[:, tf.newaxis, tf.newaxis, :, :]\n",
    "summed_hessian_from_jacobian = tf.reduce_sum(input_times_hessian, axis=(2, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.GradientTape(persistent=True) as second_order_tape:\n",
    "    second_order_tape.watch(batch_input)\n",
    "    with tf.GradientTape() as first_order_tape:\n",
    "        first_order_tape.watch(batch_input)\n",
    "        output = model(batch_input)\n",
    "    gradient = first_order_tape.gradient(output, batch_input)\n",
    "    gradient_times_input = batch_input * gradient\n",
    "    summed_gradient_times_input = tf.reduce_sum(gradient_times_input, axis=2)\n",
    "batch_summed_hessian = second_order_tape.batch_jacobian(summed_gradient_times_input, batch_input)\n",
    "batch_summed_hessian_times_input = batch_summed_hessian * batch_input[:, tf.newaxis, :, :]\n",
    "batch_summed_summed_hessian_times_input = tf.reduce_sum(batch_summed_hessian_times_input, axis=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([10, 32, 32, 64])"
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_summed_hessian.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(10, 32, 32), dtype=float64, numpy=\n",
       "array([[[-1.05502390e-01, -1.70172225e-02, -5.44662466e-02, ...,\n",
       "         -1.33994116e-03,  2.69511727e-03, -6.65550233e-04],\n",
       "        [-1.70172197e-02,  9.86773240e-02,  5.66338611e-02, ...,\n",
       "         -1.95719531e-03,  4.48519505e-04,  3.24663182e-05],\n",
       "        [-5.44662534e-02,  5.66338586e-02, -3.22873205e-01, ...,\n",
       "          5.97984305e-03,  1.52160315e-03,  3.36281641e-03],\n",
       "        ...,\n",
       "        [-1.33994129e-03, -1.95719384e-03,  5.97984041e-03, ...,\n",
       "         -4.31650734e-02, -1.94025766e-02, -1.09749889e-02],\n",
       "        [ 2.69511760e-03,  4.48518748e-04,  1.52160548e-03, ...,\n",
       "         -1.94025774e-02,  6.06804017e-03, -4.90592735e-02],\n",
       "        [-6.65548825e-04,  3.24665108e-05,  3.36281812e-03, ...,\n",
       "         -1.09749909e-02, -4.90592671e-02,  1.43230623e-01]],\n",
       "\n",
       "       [[ 4.98012997e-02,  1.09805460e-02,  2.28644936e-03, ...,\n",
       "         -2.48018607e-04, -4.47010343e-04, -7.25157506e-04],\n",
       "        [ 1.09805466e-02, -1.24665227e-02, -1.21603710e-02, ...,\n",
       "          3.14699487e-03, -2.92003669e-04,  3.58036839e-04],\n",
       "        [ 2.28644707e-03, -1.21603704e-02,  7.04566939e-02, ...,\n",
       "         -2.68601151e-03,  1.43370567e-03, -1.09785838e-03],\n",
       "        ...,\n",
       "        [-2.48019115e-04,  3.14699687e-03, -2.68601177e-03, ...,\n",
       "          1.94545170e-02,  1.81494966e-02,  5.26090340e-03],\n",
       "        [-4.47010176e-04, -2.92003282e-04,  1.43370444e-03, ...,\n",
       "          1.81494983e-02,  8.63911111e-03,  6.67794107e-03],\n",
       "        [-7.25157716e-04,  3.58037542e-04, -1.09785831e-03, ...,\n",
       "          5.26090241e-03,  6.67793715e-03, -8.60635235e-02]],\n",
       "\n",
       "       [[ 1.77587699e-02,  7.96652104e-04,  1.36658610e-02, ...,\n",
       "          3.28900782e-03, -3.26259889e-03, -2.94892614e-03],\n",
       "        [ 7.96650731e-04,  1.10358255e-01,  1.98050046e-02, ...,\n",
       "          9.47455231e-03, -4.98312556e-03, -6.82798380e-03],\n",
       "        [ 1.36658639e-02,  1.98050206e-02, -2.95258563e-02, ...,\n",
       "          9.24705452e-03, -6.45727977e-03, -2.68087816e-03],\n",
       "        ...,\n",
       "        [ 3.28900661e-03,  9.47455071e-03,  9.24705195e-03, ...,\n",
       "         -6.72810231e-02, -1.79717212e-02,  9.00560317e-03],\n",
       "        [-3.26259845e-03, -4.98312413e-03, -6.45728118e-03, ...,\n",
       "         -1.79717186e-02,  7.21456950e-02,  2.89192576e-02],\n",
       "        [-2.94892639e-03, -6.82798249e-03, -2.68087727e-03, ...,\n",
       "          9.00560633e-03,  2.89192657e-02,  1.27660623e-01]],\n",
       "\n",
       "       ...,\n",
       "\n",
       "       [[ 6.39216996e-02, -1.11535326e-02, -4.72571948e-03, ...,\n",
       "         -6.17456292e-03, -4.52969125e-03, -5.18381734e-04],\n",
       "        [-1.11535349e-02,  9.29966505e-02,  4.80743556e-03, ...,\n",
       "          8.65158173e-03, -4.71348751e-03, -6.49619423e-03],\n",
       "        [-4.72572655e-03,  4.80744130e-03, -4.18128280e-02, ...,\n",
       "          5.90708116e-03,  1.28476352e-02,  3.61839882e-03],\n",
       "        ...,\n",
       "        [-6.17456249e-03,  8.65158051e-03,  5.90707640e-03, ...,\n",
       "          3.94465148e-01,  8.57550533e-03,  1.59396792e-02],\n",
       "        [-4.52969210e-03, -4.71349132e-03,  1.28476335e-02, ...,\n",
       "          8.57552112e-03,  2.14752285e-01,  2.29358503e-02],\n",
       "        [-5.18381496e-04, -6.49619629e-03,  3.61839883e-03, ...,\n",
       "          1.59396804e-02,  2.29358511e-02,  1.42110255e-01]],\n",
       "\n",
       "       [[ 9.45755016e-02,  1.69363971e-02,  2.06723117e-02, ...,\n",
       "          7.81347486e-03, -8.60906841e-03,  6.76591366e-03],\n",
       "        [ 1.69364004e-02, -1.10524152e-01, -8.10891610e-03, ...,\n",
       "         -1.79402923e-03,  1.69780705e-03, -1.03741305e-03],\n",
       "        [ 2.06723118e-02, -8.10892808e-03,  1.11100327e-01, ...,\n",
       "          5.11888107e-05, -3.68980838e-03,  6.84421199e-03],\n",
       "        ...,\n",
       "        [ 7.81347666e-03, -1.79402788e-03,  5.11917611e-05, ...,\n",
       "          3.74522113e-01,  4.10812556e-03,  1.97389088e-02],\n",
       "        [-8.60907141e-03,  1.69780496e-03, -3.68981203e-03, ...,\n",
       "          4.10812893e-03, -8.42574885e-02, -3.38535261e-02],\n",
       "        [ 6.76591396e-03, -1.03741121e-03,  6.84421207e-03, ...,\n",
       "          1.97389140e-02, -3.38535284e-02, -4.01721498e-02]],\n",
       "\n",
       "       [[-1.99715363e-02,  1.40825762e-02, -2.72488778e-02, ...,\n",
       "         -2.11194274e-03, -4.71829508e-04, -4.22506541e-03],\n",
       "        [ 1.40825732e-02, -1.50310251e-01,  2.38326213e-02, ...,\n",
       "         -1.48941848e-02,  2.23857081e-03, -3.92097751e-03],\n",
       "        [-2.72488752e-02,  2.38326156e-02, -1.15096649e-01, ...,\n",
       "         -2.59511620e-02,  8.37400503e-03,  5.47339113e-03],\n",
       "        ...,\n",
       "        [-2.11194255e-03, -1.48941894e-02, -2.59511657e-02, ...,\n",
       "         -2.13244627e-02, -7.03517876e-04, -6.20344137e-03],\n",
       "        [-4.71828861e-04,  2.23857211e-03,  8.37400403e-03, ...,\n",
       "         -7.03522974e-04,  4.36545053e-02,  3.80307407e-02],\n",
       "        [-4.22506575e-03, -3.92097682e-03,  5.47339182e-03, ...,\n",
       "         -6.20344092e-03,  3.80307388e-02, -9.13110889e-02]]])>"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_summed_summed_hessian_times_input"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
